package smcrypto

import (
	"encoding/asn1"
	"errors"
	"fmt"
	"github.com/ethereum/go-ethereum/sm2"
	"github.com/ethereum/go-ethereum/sm3"
	"math/big"
)

// SignatureLength is the size in bytes of a signature produced by
// SignSerialize: three consecutive 32-byte big-endian slots for r, s and v.
const SignatureLength = 3 * 32

// Signature holds the three integer components handled by this package.
// R and S are the signature values; V is compared against the message digest
// value e during public-key recovery. The field order defines the element
// order of the ASN.1 SEQUENCE used by MarshalSign and UnmarshalSign, so it
// must not change.
type Signature struct {
	R *big.Int
	S *big.Int
	V *big.Int
}

// MarshalSign ASN.1-encodes r, s and v, in that order, using the Signature
// layout. On failure it returns a nil slice together with the encoder's error.
func MarshalSign(r, s, v *big.Int) ([]byte, error) {
	return asn1.Marshal(Signature{R: r, S: s, V: v})
}

func UnmarshalSign(sign []byte) (r, s, v *big.Int, err error) {
	sm2Sign := new(Signature)
	_, err = asn1.Unmarshal(sign, sm2Sign)
	if err != nil {
		return nil, nil, nil, err
	}
	return sm2Sign.R, sm2Sign.S, sm2Sign.V, nil
}

func SignSerialize(r, s, v *big.Int) ([]byte, error) {
	sig := make([]byte, SignatureLength)

	rbytes := r.Bytes()
	sbytes := s.Bytes()
	vbytes := v.Bytes()

	copy(sig[32-len(rbytes):32], rbytes)
	copy(sig[64-len(sbytes):64], sbytes)
	copy(sig[96-len(vbytes):96], vbytes)

	return sig, nil
}

// SignUnSerialize is the inverse of SignSerialize: it requires exactly
// SignatureLength bytes and reads r, s and v as unsigned big-endian integers
// from three consecutive 32-byte slots.
func SignUnSerialize(sign []byte) (r, s, v *big.Int, err error) {
	if len(sign) != SignatureLength {
		return nil, nil, nil, errors.New("length is not valid")
	}
	parts := make([]*big.Int, 3)
	for i := range parts {
		lo, hi := i*32, (i+1)*32
		parts[i] = new(big.Int).SetBytes(sign[lo:hi])
	}
	return parts[0], parts[1], parts[2], nil
}

// decompressPoint recovers the Y coordinate of the point on curve whose X
// coordinate is x. It picks the root whose parity matches ybit (true = odd).
//
// It solves y^2 = x^3 + a*x + b (mod p) with y = rhs^((p+1)/4) mod p. That
// exponent yields a square root only when p ≡ 3 (mod 4) and rhs is a
// quadratic residue. The candidate is therefore squared and checked, so an x
// with no point on the curve returns an error instead of a bogus coordinate.
func decompressPoint(curve sm2.P256V1Curve, x *big.Int, ybit bool) (*big.Int, error) {
	p := curve.Params().P

	// rhs = x^3 + a*x + b (mod p)
	rhs := new(big.Int).Mul(x, x)
	rhs.Mul(rhs, x)
	rhs.Add(rhs, new(big.Int).Mul(curve.A, x))
	rhs.Add(rhs, curve.Params().B)
	rhs.Mod(rhs, p)

	y := new(big.Int).Exp(rhs, QPlus1Div4(curve), p)

	// Verify the root: if rhs is not a quadratic residue, x is not on the curve.
	check := new(big.Int).Mul(y, y)
	check.Mod(check, p)
	if check.Cmp(rhs) != 0 {
		return nil, errors.New("x is not on the curve")
	}

	if ybit != isOdd(y) {
		y.Sub(p, y)
		// Keep y in [0, p): when y was 0, p - y == p must reduce back to 0.
		y.Mod(y, p)
	}
	// Still mismatched only when y == 0, where no root of the requested
	// parity exists.
	if ybit != isOdd(y) {
		return nil, fmt.Errorf("ybit doesn't match oddness")
	}
	return y, nil
}

// recoverPublicKey reconstructs one candidate public key from an SM2 signature.
//
// sig.R and sig.S are the signature values and sig.V is the digest value e
// that was signed. iter selects one of four candidates:
//   - iter/2 is the number of multiples of n added to x1 = (r - e) mod n.
//   - iter%2 selects the odd (1) or even (0) Y root of the ephemeral point R.
//
// When doChecks is true, R is also verified to satisfy n*R = O.
//
// msg is currently unused; the digest is taken from sig.V.
//
// The key is computed as Q = (s + r)^-1 * (R - s*G). This follows from the
// SM2 signing equation s = (1 + d)^-1 * (k - r*d) mod n, which gives
// d * (s + r) = k - s.
func recoverPublicKey(curve sm2.P256V1Curve, sig *Signature, msg []byte, iter int, doChecks bool) (*sm2.PublicKey, error) {
	if sig == nil || sig.R == nil || sig.S == nil || sig.V == nil {
		return nil, errors.New("signature has nil component")
	}
	params := curve.Params()
	n := params.N

	// 1.1 x1 = ((r - e) mod n) + j*n with j = iter/2. big.Int.Mod always
	// returns a non-negative value, so x1 is never negative when e > r.
	Rx := new(big.Int).Sub(sig.R, sig.V)
	Rx.Mod(Rx, n)
	Rx.Add(Rx, new(big.Int).Mul(n, big.NewInt(int64(iter/2))))
	if Rx.Cmp(params.P) != -1 {
		return nil, errors.New("calculated Rx is larger than curve P")
	}

	// 1.2/1.3 Decompress R. Odd iterations select the odd Y root.
	Ry, err := decompressPoint(curve, Rx, iter%2 == 1)
	if err != nil {
		return nil, err
	}

	// 1.4 Check that n*R is the point at infinity.
	if doChecks {
		nRx, nRy := curve.ScalarMult(Rx, Ry, n.Bytes())
		if nRx.Sign() != 0 || nRy.Sign() != 0 {
			return nil, errors.New("n*R does not equal the point at infinity")
		}
	}

	// Q = (s + r)^-1 * R + (-s * (s + r)^-1) * G
	invSR := new(big.Int).ModInverse(new(big.Int).Add(sig.S, sig.R), n)
	if invSR == nil {
		// ModInverse returns nil when s + r ≡ 0 (mod n); no key can be recovered.
		return nil, errors.New("s + r is not invertible mod n")
	}
	// First term: (s + r)^-1 * R.
	t1x, t1y := curve.ScalarMult(Rx, Ry, invSR.Bytes())
	// Second term: (-s * (s + r)^-1 mod n) * G.
	coeff := new(big.Int).Neg(sig.S)
	coeff.Mul(coeff, invSR)
	coeff.Mod(coeff, n)
	t2x, t2y := curve.ScalarBaseMult(coeff.Bytes())

	Qx, Qy := curve.Add(t1x, t1y, t2x, t2y)
	return &sm2.PublicKey{
		Curve: curve,
		X:     Qx,
		Y:     Qy,
	}, nil
}

// ParsePublicKey recovers the public key that produced sign over msg.
//
// It tries the four candidates from recoverPublicKey in order, skipping any
// that fail to recover. A candidate is accepted when sm2.CalculateE, computed
// with a fresh SM3 digest, the candidate key and sm2SignDefaultUserId over
// msg, equals sign.V. If no candidate matches, an error is returned.
func ParsePublicKey(curve sm2.P256V1Curve, sign *Signature, msg []byte) (*sm2.PublicKey, error) {
	for candidate := 0; candidate < 4; candidate++ {
		pub, err := recoverPublicKey(curve, sign, msg, candidate, true)
		if err != nil {
			// This candidate is not viable; move on to the next one.
			continue
		}
		e := sm2.CalculateE(sm3.New(), &curve, pub.X, pub.Y, sm2SignDefaultUserId, msg)
		if e.Cmp(sign.V) == 0 {
			return pub, nil
		}
	}
	return nil, errors.New("parse public key error")
}

// QPlus1Div4 returns (P + 1) / 4, where P is the curve's field prime.
// decompressPoint uses it as the exponent for a modular square root.
func QPlus1Div4(curve sm2.P256V1Curve) *big.Int {
	four := big.NewInt(4)
	pPlusOne := new(big.Int).Add(curve.P, big.NewInt(1))
	return pPlusOne.Div(pPlusOne, four)
}

// isOdd reports whether the least-significant bit of a is set.
func isOdd(a *big.Int) bool {
	lowBit := a.Bit(0)
	return lowBit != 0
}
